2D Conv transpose support #311
Conversation
gradtest(∇conv_data, rand(10, 10, 3, 2), randn(2, 2, 2, 3))
gradtest(∇conv_data, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 2, 3))
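These gradtest calls exercise NNlib's ∇conv_data directly, since the forward pass of a transposed convolution is exactly the gradient of an ordinary convolution with respect to its input. A minimal sketch of that relationship, assuming the keyword-based ∇conv_data API from NNlib #54 that this PR targets (later NNlib releases changed the signature); the sizes and variable names are illustrative:

using NNlib

x = rand(Float32, 4, 4, 3, 10)    # W × H × C × N array, playing the role of a conv's output gradient
w = randn(Float32, 2, 2, 64, 3)   # kernel laid out as (kW, kH, C_out, C_in) from the transpose's point of view
# Treat x as the upstream gradient of a convolution and recover the larger, input-space array.
y = ∇conv_data(x, w; stride = (1, 1), pad = (0, 0), dilation = (1, 1))
size(y)                           # expected (5, 5, 64, 10): each spatial dim grows from 4 to (4 - 1) + 2 = 5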
Force-pushed from 65799d1 to 193c4de.
Force-pushed from cdaa0b4 to e86365e.
The NNlib counterpart of this PR needs to be merged for the checks to pass.
What's the status of this PR? I'm working on an InfoGAN model for the Flux model zoo and can't write a version that works on SVHN without conv transpose.
Is there anything I can help with? I'm waiting for this.
@tejank10 this PR looks generally good to me, mind just updating it?
test/tracker.jl (Outdated)
@@ -1,7 +1,11 @@
using Flux
using Flux.Tracker, Test, NNlib
<<<<<<< HEAD

Merge artifact.
src/layers/conv.jl (Outdated)
stride = 1, pad = 0, dilation = 1) where {T,N} =
  ConvTranspose(σ, w, b, expand.(sub2(Val(N)), (stride, pad, dilation))...)

ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
julia> ConvTranspose((2, 2), 3=>3)
ERROR: UndefVarError: initn not defined
Stacktrace:
[1] ConvTranspose(::Tuple{Int64,Int64}, ::Pair{Int64,Int64}, ::Function) at /home/vchuravy/.julia/packages/Flux/hguaX/src/layers/conv.jl:84 (repeats 2 times)
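Until the default initializer is fixed, a hedged workaround (using an initializer Flux does export) is to pass init explicitly:

using Flux

ConvTranspose((2, 2), 3 => 3; init = Flux.glorot_uniform)   # sidestep the undefined initn default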
Also:
julia> ConvTranspose((2, 2), 3=>64; init=Flux.glorot_uniform)(rand(4, 4, 3, 10))
ERROR: MethodError: no method matching ∇conv_data(::Array{Float64,4}, ::Array{Float32,4}; stride=(1, 1), pad=(0, 0), dilation=(1, 1))
Closest candidates are:
∇conv_data(::AbstractArray, ::TrackedArray; kw...) at /home/vchuravy/.julia/packages/Flux/hguaX/src/tracker/lib/array.jl:390
∇conv_data(::TrackedArray, ::AbstractArray; kw...) at /home/vchuravy/.julia/packages/Flux/hguaX/src/tracker/lib/array.jl:391
∇conv_data(::A<:AbstractArray, ::A<:AbstractArray; size, pad, stride, dilation, flipkernel) where A<:AbstractArray at /home/vchuravy/.julia/packages/NNlib/nf8OC/src/conv.jl:74
...
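For what it's worth, the second failure is an element-type mismatch rather than missing functionality: rand(4, 4, 3, 10) produces a Float64 array while the glorot_uniform weights are Float32, and the matching NNlib method requires both arguments to share an element type. A hedged workaround sketch against the PR branch (sizes are illustrative):

using Flux

ct = ConvTranspose((2, 2), 3 => 64; init = Flux.glorot_uniform)
x  = rand(Float32, 4, 4, 3, 10)   # match the layer's Float32 weights instead of defaulting to Float64
y  = ct(x)
size(y)                           # expected (5, 5, 64, 10) with the default stride and padding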
Thanks for pointing these out, I've fixed them in the latest commit.
Thanks! I tried to use it again and everything seems to work fine on the CPU, but in the GPU case I get an error.
I just made a PR in CuArrays (#223) to reflect these changes.
This looks like it never got merged. Is there no ConvTranspose in Flux?
@MikeInnes if #54 is fine, then let's get it and this PR merged?
src/layers/conv.jl (Outdated)
@@ -77,6 +125,7 @@ struct DepthwiseConv{N,F,A,V}
  bias::V
  stride::NTuple{N,Int}
  pad::NTuple{N,Int}
  dilation::NTuple{N,Int}
This seems like perhaps an unrelated change?
In my limited autoencoder testing, this seems to be working, but I'm worried about the failing gradient tests. @tejank10 do you know why they're failing?
Trying tests with …
I just merged that PR. Worth testing this again, figuring out if we need a CuArrays tag etc.
You should be able to add NNlib master to the manifest and get it tested that way.
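For anyone following along, a sketch of doing that with the package manager from a local Flux checkout (the path is illustrative, and a CuArrays branch may also be needed for the GPU path):

using Pkg

Pkg.activate(".")                                     # activate the local Flux project and its Manifest
Pkg.add(PackageSpec(name = "NNlib", rev = "master"))  # record NNlib master in the Manifest
Pkg.test("Flux")                                      # run the test suite against it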
Awesome stuff @tejank10, thanks!
Thanks for working on this. It will be useful in implementing GANs using Flux.
Also thanks a lot to @staticfloat for the review!
Does this support only 2D images, as the name suggests, or 1D and 3D as well?
It should support 1D and 3D as well.
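A quick sketch of what that looks like, assuming the merged layer and Flux's usual (spatial dims..., channels, batch) data layout; the kernel sizes and channel counts are just for illustration:

using Flux

x1 = rand(Float32, 10, 3, 2)          # 1D: width × channels × batch
x2 = rand(Float32, 10, 10, 3, 2)      # 2D: width × height × channels × batch
x3 = rand(Float32, 10, 10, 10, 3, 2)  # 3D

size(ConvTranspose((2,), 3 => 4)(x1))       # expected (11, 4, 2)
size(ConvTranspose((2, 2), 3 => 4)(x2))     # expected (11, 11, 4, 2)
size(ConvTranspose((2, 2, 2), 3 => 4)(x3))  # expected (11, 11, 11, 4, 2)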
This PR adds 2D conv transpose support to Flux, along with #54 in NNlib.jl.